{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "958524a2-cb56-439e-850e-032dd10478f2",
   "metadata": {},
   "source": [
    "# Sampling from a diffusion model\n",
    "\n",
    "<!--- @wandbcode{dlai_03} -->\n",
    "\n",
    "In this notebook we will sample from the previously trained diffusion model:\n",
    "- Compare the samples generated by the DDPM and DDIM samplers.\n",
    "- Visualize samples from the conditional diffusion model side by side in a W&B Table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "700e687c",
   "metadata": {
    "height": 148,
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from types import SimpleNamespace\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from utilities import *\n",
    "\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcaf7a29-782c-4735-991f-4408f5ec6128",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Log in to W&B; anonymous=\"allow\" lets the notebook run without a W&B account\n",
    "wandb.login(anonymous=\"allow\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c0d229a",
   "metadata": {},
   "source": [
    "# Setting Things Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c3a942",
   "metadata": {
    "height": 335,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Wandb Params\n",
    "MODEL_ARTIFACT = \"dlai-course/model-registry/SpriteGen:latest\"\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "config = SimpleNamespace(\n",
    "    # hyperparameters\n",
    "    num_samples = 30,\n",
    "    seed = 42,  # seeds torch so the sampled noise is reproducible (also logged to W&B)\n",
    "    \n",
    "    # ddpm sampler hyperparameters\n",
    "    timesteps = 500,\n",
    "    beta1 = 1e-4,\n",
    "    beta2 = 0.02,\n",
    "    \n",
    "    # ddim sampler hyperparameters (passed as `n` to the DDIM sampler)\n",
    "    ddim_n = 25,\n",
    "    \n",
    "    # network hyperparameters (sprites are height x height pixels)\n",
    "    height = 16,\n",
    ")\n",
    "\n",
    "# Seed the global torch RNG so that re-running the notebook from the top\n",
    "# draws the same initial noise; the trailing `;` hides the Generator repr.\n",
    "torch.manual_seed(config.seed);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb43f98f",
   "metadata": {},
   "source": [
    "In the previous notebook we saved the best model as a W&B Artifact (W&B's way of storing and versioning files produced by runs). We will now load the model from W&B and set up the sampling loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab66255",
   "metadata": {
    "height": 454
   },
   "outputs": [],
   "source": [
    "def load_model(model_artifact_name):\n",
    "    \"\"\"Download a model artifact from W&B and rebuild the network it stores.\n",
    "\n",
    "    Args:\n",
    "        model_artifact_name: full name of the W&B model artifact, e.g.\n",
    "            \"entity/project/name:alias\".\n",
    "\n",
    "    Returns:\n",
    "        A ContextUnet holding the artifact's weights, in eval mode, on DEVICE.\n",
    "    \"\"\"\n",
    "    artifact = wandb.Api().artifact(model_artifact_name, type=\"model\")\n",
    "    artifact_dir = Path(artifact.download())\n",
    "\n",
    "    # The architecture hyperparameters come from the config of the run that\n",
    "    # logged the artifact, so the rebuilt network matches the saved weights.\n",
    "    run_config = artifact.logged_by().config\n",
    "\n",
    "    # NOTE(review): torch.load unpickles the file, which can execute arbitrary\n",
    "    # code; only load artifacts from a trusted registry (on torch >= 1.13,\n",
    "    # consider weights_only=True for a plain state dict).\n",
    "    state_dict = torch.load(artifact_dir / \"context_model.pth\", map_location=\"cpu\")\n",
    "\n",
    "    model = ContextUnet(\n",
    "        in_channels=3,\n",
    "        n_feat=run_config[\"n_feat\"],\n",
    "        n_cfeat=run_config[\"n_cfeat\"],\n",
    "        height=run_config[\"height\"],\n",
    "    )\n",
    "    model.load_state_dict(state_dict)\n",
    "\n",
    "    # eval() returns the module itself, so it can be chained with .to()\n",
    "    return model.eval().to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b47633e2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Download the registered model and rebuild it on DEVICE\n",
    "nn_model = load_model(model_artifact_name=MODEL_ARTIFACT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe8eb277",
   "metadata": {},
   "source": [
    "## Sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45d92c52-8a11-450c-bc78-ffa221af2fa3",
   "metadata": {},
   "source": [
    "We will sample and log the generated samples to wandb."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "146424d3",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# setup_ddpm returns a pair; only the second element (the sampler used below) is kept\n",
    "_, sample_ddpm_context = setup_ddpm(config.beta1, config.beta2, config.timesteps, DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00b9ef16-1848-476d-a9dd-09175b8f0e3c",
   "metadata": {},
   "source": [
    "Let's define a batch of initial noise images and a context vector to condition on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88afdba",
   "metadata": {
    "height": 233,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Noise vector\n",
    "# x_T ~ N(0, 1), sample initial noise\n",
    "noises = torch.randn(config.num_samples, 3,\n",
    "                     config.height, config.height).to(DEVICE)\n",
    "\n",
    "# A fixed context vector to sample from: the samples are split into equal,\n",
    "# contiguous groups, one per class\n",
    "# (0: hero, 1: non-hero, 2: food, 3: spell, 4: side-facing).\n",
    "# Derived from config.num_samples so that changing it keeps the noise and\n",
    "# context batches the same size.\n",
    "n_classes = 5\n",
    "assert config.num_samples % n_classes == 0, (\n",
    "    f\"num_samples ({config.num_samples}) must be a multiple of {n_classes}\"\n",
    ")\n",
    "class_ids = torch.arange(n_classes).repeat_interleave(config.num_samples // n_classes)\n",
    "ctx_vector = F.one_hot(class_ids, n_classes).to(DEVICE).float()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cbf9ef8-619a-4052-a138-a88c0f0f8b0b",
   "metadata": {},
   "source": [
    "Let's bring in the faster DDIM sampler from the diffusion course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c1a945d",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Same beta schedule and number of timesteps as the DDPM sampler above\n",
    "sample_ddim_context = setup_ddim(config.beta1, config.beta2, config.timesteps, DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b838be-8fa1-4c12-9c4f-e40dfacc08e1",
   "metadata": {},
   "source": [
    "### Generating the samples\n",
    "Let's compute the DDPM samples as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e24210-4885-4559-92e1-db10566ef5ea",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Keep the generated samples; the sampler's second return value is discarded\n",
    "ddpm_samples, _ = sample_ddpm_context(nn_model, noises, ctx_vector)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836584a1-26b5-45b1-98c9-0c45d639c8f9",
   "metadata": {},
   "source": [
    "For DDIM we can control the step size with the `n` parameter, which we set from `config.ddim_n`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b07c26-0ac2-428a-8351-34f8b7228074",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Same starting noise and context as DDPM; `n` controls the DDIM step size\n",
    "ddim_samples, _ = sample_ddim_context(nn_model, noises, ctx_vector, n=config.ddim_n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5db3cb01",
   "metadata": {},
   "source": [
    "### Visualizing generations in a Table\n",
    "Let's create a `wandb.Table` to store our generations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f1d3b94",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "table = wandb.Table(columns=[\n",
    "    \"input_noise\",  # starting noise shared by both samplers\n",
    "    \"ddpm\",\n",
    "    \"ddim\",\n",
    "    \"class\",\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85be303d-0f0b-4df4-8c87-bd1bfb6145a2",
   "metadata": {},
   "source": [
    "We add the rows to the table one by one. We also wrap each image in `wandb.Image` so that it renders correctly in the UI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "481afea1-ae53-4b5b-a3db-1d49be0733a3",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "rows = zip(noises, ddpm_samples, ddim_samples, to_classes(ctx_vector))\n",
    "\n",
    "for noise, ddpm_img, ddim_img, label in rows:\n",
    "    # wandb.Image makes the W&B UI render each tensor as a picture\n",
    "    table.add_data(\n",
    "        wandb.Image(noise),\n",
    "        wandb.Image(ddpm_img),\n",
    "        wandb.Image(ddim_img),\n",
    "        label,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "987cee86-2db1-4a2a-9d14-f70c6248ecb9",
   "metadata": {},
   "source": [
    "We log the table to W&B. Here we use `wandb.init` as a context manager, which ensures that the run is finished when the `with` block exits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc7a2ca-ae05-4462-9ae3-82eb1a6dbc27",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "with wandb.init(project=\"dlai_sprite_diffusion\",\n",
    "                job_type=\"samplers_battle\",\n",
    "                config=config) as run:\n",
    "    # Log to the run opened above; leaving the with-block finishes it\n",
    "    run.log({\"samplers_table\": table})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36cde325-5a53-45c9-ac57-6b52553d00d1",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
